import torch
import torch.optim as optim
import torch
import sympy as sp
from optimizer.expl_sgd import esgd
from optimizer.expl_adam import eadam

def get_optimizers(args):
    """Create one optimizer per entry in ``args.optimizers``.

    Every optimizer receives its own freshly created parameter tensors,
    initialised from ``args.start_point``. The optimizers therefore never share
    parameters, and each one starts from the same point.

    Args:
        args: Object exposing
            * ``optimizers`` -- iterable of mappings, each with a ``"name"`` key
              (one of the keys of ``optimizer_dict`` below) and an optional
              ``"params"`` mapping of keyword arguments forwarded to the
              optimizer constructor. A missing or ``None`` ``"params"`` means
              "use the constructor defaults".
            * ``start_point`` -- iterable of numbers, one per coordinate.

    Returns:
        list: Optimizer instances, in the same order as ``args.optimizers``.

    Raises:
        ValueError: If an entry names an optimizer that is not supported.
    """
    optimizer_dict = {
        "sgd": optim.SGD,
        "adam": optim.Adam,
        "rmsprop": optim.RMSprop,
        "adamW": optim.AdamW,
        "adagrad": optim.Adagrad,
        "adadelta": optim.Adadelta,
        "esgd": esgd,
        "eadam": eadam
    }

    optimizers = []
    for spec in args.optimizers:
        name = spec["name"]
        # Tolerate a missing "params" key or an empty (None) value from a config file.
        params = spec.get("params") or {}
        if name not in optimizer_dict:
            raise ValueError(f"Optimizer '{name}' is not supported. "
                             f"Supported optimizers are: {list(optimizer_dict.keys())}")

        # Fresh single-element float32 leaf tensors for every optimizer, so the
        # optimizers do not update each other's parameters.
        start_params = [
            torch.tensor([value], dtype=torch.float32, requires_grad=True)
            for value in args.start_point
        ]
        optimizers.append(optimizer_dict[name](start_params, **params))

    return optimizers


# Function names emitted by sympy's default Python code printer, mapped to their
# torch equivalents so that autograd can track them. Names not listed here are
# resolved by lambdify itself (per the sympy docs, the ``math`` module), which
# returns plain floats and therefore drops gradients.
_TORCH_FUNCTIONS = {
    "sin": torch.sin, "cos": torch.cos, "tan": torch.tan,
    "asin": torch.asin, "acos": torch.acos, "atan": torch.atan, "atan2": torch.atan2,
    "sinh": torch.sinh, "cosh": torch.cosh, "tanh": torch.tanh,
    "asinh": torch.asinh, "acosh": torch.acosh, "atanh": torch.atanh,
    "exp": torch.exp, "expm1": torch.expm1,
    "log": torch.log, "log2": torch.log2, "log10": torch.log10, "log1p": torch.log1p,
    "sqrt": torch.sqrt,
    "floor": torch.floor, "ceil": torch.ceil,
    "erf": torch.erf, "erfc": torch.erfc,
}


def _as_tensor(value):
    """Return *value* unchanged if it is a tensor; otherwise wrap it in a tensor of the default float dtype."""
    if isinstance(value, torch.Tensor):
        return value
    return torch.as_tensor(value, dtype=torch.get_default_dtype())


def _tensor_safe(fn):
    """Wrap torch function *fn* so it also accepts plain Python numbers.

    The generated code calls these functions on numeric constants as well
    (e.g. ``sqrt(2)`` or ``log(2)`` inside ``log(x, 2)``), and torch's
    elementwise ops reject non-tensor inputs.
    """
    def wrapped(*fn_args):
        return fn(*(_as_tensor(a) for a in fn_args))
    return wrapped


def _torch_namespace():
    """Build the ``modules`` mapping passed to ``sp.lambdify``."""
    namespace = {name: _tensor_safe(fn) for name, fn in _TORCH_FUNCTIONS.items()}
    namespace["torch"] = torch
    return namespace


def get_function(args):
    """Build a PyTorch-evaluable callable from the expression in ``args.function``.

    The expression is parsed with ``sp.sympify``. Its free symbols, sorted by
    name, become the positional parameters of the returned callable.

    Args:
        args: Object with a ``function`` attribute holding an expression
            accepted by ``sp.sympify`` (e.g. ``"sin(x) + y**2"``).

    Returns:
        ``torch_function(x)`` when the expression has exactly one free symbol;
        otherwise ``torch_function(*args)``, taking one positional argument per
        free symbol in name order.

    Raises:
        ValueError: Raised by the returned multi-symbol callable when it is
            called with the wrong number of positional arguments.
    """
    # NOTE(review): sp.sympify evaluates its input string with eval(); only feed
    # it trusted expressions (e.g. from the local command line or config).
    expr = sp.sympify(args.function)
    free_symbols = sorted(expr.free_symbols, key=lambda s: str(s))

    # Convert to a PyTorch function; supported math functions resolve to torch ops.
    lambdified = sp.lambdify(free_symbols, expr, modules=_torch_namespace())

    # Return a callable whose arity matches the number of symbols.
    if len(free_symbols) == 1:
        def torch_function(x):
            return lambdified(x)
    else:
        def torch_function(*args):
            if len(args) != len(free_symbols):
                raise ValueError(f"Expected {len(free_symbols)} arguments, but got {len(args)}.")
            return lambdified(*args)

    return torch_function


